
## IMPORT

from pyspark.context import SparkContext, SparkConf
from pyspark.sql import SparkSession, Window, types
from pyspark.sql.types import StructType, IntegerType
from pyspark.sql.context import SQLContext
import pyspark.sql.functions as f

from math import floor, ceil
import sys
import pandas as pd
import json

from splink.spark.spark_linker import SparkLinker
from splink.spark.spark_comparison_library import jaro_winkler_at_thresholds, levenshtein_at_thresholds, exact_match
from splink.charts import save_offline_chart

## SPARK SESSION

# Reuse the running SparkContext if there is one (otherwise start one with the
# default SparkConf) and wrap it in a SparkSession.
sc = SparkContext.getOrCreate(conf=SparkConf())
spark = SparkSession(sc)

# Only surface WARN-and-above driver logs.
sc.setLogLevel('WARN')
# Checkpoint directory for the context; presumably needed by the graphframes-based
# clustering step, judging by the path name — TODO confirm.
sc.setCheckpointDir("splink_sandbox/temp_graphframes/")

# Expose the Java Jaro-Winkler implementation to Spark SQL under the name
# `jaro_winkler`, returning a double. Presumably this is what the splink
# jaro_winkler_at_thresholds comparisons call — verify against the splink version in use.
spark.udf.registerJavaFunction(
    'jaro_winkler',
    'uk.gov.moj.dash.linkage.JaroWinklerSimilarity',
    returnType=types.DoubleType(),
)

## LOAD DATA FROM SLURM PARAMS

# 50 states plus DC. sys.argv[1] is an index into this tuple (presumably the
# SLURM array task id, per the section title — TODO confirm).
STATES = (
    "AL", "AK", "AZ", "AR", "CA", "CO", "CT", "DE", "DC", "FL", "GA", "HI",
    "ID", "IL", "IN", "IA", "KS", "KY", "LA", "ME", "MD", "MA", "MI", "MN", "MS", "MO",
    "MT", "NE", "NV", "NH", "NJ", "NM", "NY", "NC", "ND", "OH", "OK", "OR", "PA", "RI",
    "SC", "SD", "TN", "TX", "UT", "VT", "VA", "WA", "WV", "WI", "WY",
)

# Validate the index explicitly. A negative value would otherwise silently pick
# a state from the end of the tuple, and a missing or non-integer argument
# would only produce a bare traceback.
try:
    state_index = int(sys.argv[1])
except (IndexError, ValueError):
    sys.exit(f"usage: {sys.argv[0]} STATE_INDEX  (integer 0-{len(STATES) - 1})")
if not 0 <= state_index < len(STATES):
    sys.exit(f"STATE_INDEX must be in 0-{len(STATES) - 1}, got {state_index}")

state = STATES[state_index].lower()

cl = spark.read.parquet(f"cl_in/{state}")

# Only records with a PO box number take part in this model.
cl = cl.filter("po_num IS NOT NULL")

# `cl` is not cached, so every count() is a full Spark job. Count once;
# repartitioning does not change the row count.
n_cl_rows = cl.count()
if n_cl_rows == 0:
    # repartition(0) is rejected by Spark; fail with a readable message instead.
    sys.exit(f"No cl rows with a po_num for state '{state}'; nothing to dedupe.")

# Aim for roughly 1000 rows per partition (ceil keeps this >= 1 here).
cl = cl.repartition(ceil(n_cl_rows / 1000))

print(f"There are {n_cl_rows} cl rows.")

## WRITE MODEL SETTINGS

# Extra conditions appended to every blocking rule below (note the leading
# " AND " so it can be concatenated directly onto a rule). A candidate pair is
# only generated when:
#   - their `gender` values are less than 1.001 apart,
#   - they are not on the same parcel (same fips_st AND same parcel_id),
#   - they are not both address_type 's' (meaning of 's' not shown here — TODO confirm),
#   - their `vintage` values are less than 4.001 apart.
no_comp = (
    " AND abs(l.gender - r.gender) < 1.001 "
    "AND NOT (l.fips_st = r.fips_st AND l.parcel_id = r.parcel_id) "
    "AND NOT (l.address_type = 's' AND r.address_type = 's') "
    "AND abs(l.vintage - r.vintage) < 4.001 "
)

# Same zip and same first name.
br1 = f"l.zip = r.zip AND l.first_only = r.first_only{no_comp}"
# Same zip and same last name.
br2 = f"l.zip = r.zip AND l.last_only = r.last_only{no_comp}"
# Same full name (first and last) in the same city.
br3 = f"l.last_only = r.last_only AND l.first_only = r.first_only AND l.city = r.city{no_comp}"
# Same PO box number and same first name.
br4 = f"l.po_num = r.po_num AND l.first_only = r.first_only{no_comp}"
# Same PO box number and same last name.
br5 = f"l.po_num = r.po_num AND l.last_only = r.last_only{no_comp}"

# splink settings for deduplicating a single table:
#   - every prediction blocking rule is salted into 20 partitions;
#   - first_m and last are compared with Jaro-Winkler at 0.94 / 0.88 (no
#     separate exact-match level) with term-frequency adjustments;
#   - po_num is compared by exact match.
settings = {
    "link_type": "dedupe_only",
    "blocking_rules_to_generate_predictions": [
        {"blocking_rule": rule, "salting_partitions": 20}
        for rule in (br1, br2, br3, br4, br5)
    ],
    "comparisons": [
        jaro_winkler_at_thresholds(
            "first_m",
            [0.94, 0.88],
            include_exact_match_level=False,
            term_frequency_adjustments=True,
        ),
        jaro_winkler_at_thresholds(
            "last",
            [0.94, 0.88],
            include_exact_match_level=False,
            term_frequency_adjustments=True,
        ),
        exact_match("po_num"),
    ],
    "retain_matching_columns": True,
    "retain_intermediate_calculation_columns": False,
    "unique_id_column_name": "uid",
    # Keeps zip_l / zip_r in the predictions for the zip post-filter.
    "additional_columns_to_retain": ["zip"],
    "max_iterations": 40,
}

## TRAIN THE MODEL

# `cl` is not cached, so each count() is a separate Spark job that re-reads the
# parquet input and re-runs the filter/repartition. Count it once and reuse it.
n_train_rows = cl.count()
print(f"Training this state PO box model with {n_train_rows} rows.")
linker = SparkLinker(input_table_or_tables = cl, settings_dict = settings, spark = spark)

# u probabilities from randomly sampled pairs; the sampling target equals the
# number of input rows.
linker.estimate_u_using_random_sampling(target_rows = n_train_rows)


def _name_level(output_column_name, comparison_vector_value):
    """Return the level of comparison `output_column_name` whose gamma is `comparison_vector_value`.

    Goes through splink's private settings API (`_settings_obj` and the
    `_get_*` helpers), so it may need updating on a splink upgrade.
    """
    return (
        linker._settings_obj
        ._get_comparison_by_output_column_name(output_column_name)
        ._get_comparison_level_by_comparison_vector_value(comparison_vector_value)
    )


# One EM session per training blocking rule. The name comparisons listed for a
# rule are deactivated in that session. For each of them, its gamma-1 level (the
# 0.88 Jaro-Winkler level under the current settings) is passed as
# comparison_levels_to_reverse_blocking_rule. Presumably this is because
# br1/br2/br3 force agreement on first_only/last_only, which are closely tied to
# first_m/last — TODO confirm.
# The levels are looked up right before each session, exactly as before.
em_sessions = (
    (br1, ["first_m"]),
    (br2, ["last"]),
    (br3, ["last", "first_m"]),
)
for blocking_rule, name_columns in em_sessions:
    linker.estimate_parameters_using_expectation_maximisation(
        blocking_rule,
        comparisons_to_deactivate = name_columns,
        comparison_levels_to_reverse_blocking_rule = [
            _name_level(column, 1) for column in name_columns
        ],
        fix_u_probabilities=False,
    )

linker.save_settings_to_json(f"logs/{state}_p.json", overwrite = True)
save_offline_chart(linker.match_weights_chart().spec, f"logs/{state}_p.html", overwrite = True)
  
## PREDICT MATCHES AND SAVE

# Highest gamma (comparison-vector value) of the "first_m" and "last"
# comparisons. They are built with
# jaro_winkler_at_thresholds(..., [0.94, 0.88], include_exact_match_level=False),
# whose levels are null (-1), 0.94 (2), 0.88 (1) and else (0).
# The filter previously tested for gamma 3, which cannot occur for these
# comparisons, so every pair with one missing name was silently dropped.
# Keep this in sync with the comparisons in `settings`.
TOP_NAME_GAMMA = 2

# Name post-filter. Keep a pair when either:
#   - both names are above the "else" level (gamma > 0), or
#   - one name is null (gamma -1) and the other is at the top level.
name_filter = (
    "(gamma_last > 0 AND gamma_first_m > 0) "
    f"OR (gamma_last = -1 AND gamma_first_m = {TOP_NAME_GAMMA}) "
    f"OR (gamma_last = {TOP_NAME_GAMMA} AND gamma_first_m = -1)"
)

df = (
    linker.predict(threshold_match_probability = 0.5)
    .as_spark_dataframe()
    # Drop pairs whose zips are both present and disagree.
    .filter("zip_l = zip_r OR zip_l IS NULL OR zip_r IS NULL")
    .filter(name_filter)
    .select("match_probability", "uid_l", "uid_r", "gamma_last", "gamma_first_m", "gamma_po_num")
)

print(f"There are {df.count()} predicted matches with posterior >= 0.5")
df.coalesce(10).write.mode("overwrite").parquet(f"data/cl_dedupe_matches/state={state}/type=p/")
